In [1]:
# Auto-reload edited Python modules and render matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
In [2]:
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *
from pathlib import Path
from functools import partial
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import re
import random

Preprocessing: tile each raw image/label pair into 224×224 crops for training.

In [3]:
# Provenance: raw data copied from the cluster path below.
#/hpf/largeprojects/MICe/mdagys/Cnp-GFP_Study/2019-06-10_labelled/raw
raw_dir = Path("raw")
raws = raw_dir.ls()  # fastai's Path.ls(): list of files in raw/
# Files come in "*_image*" / "*_label*" pairs; sorting both lists the same
# way lets zip() pair each image with its label downstream.
images = sorted([raw_path for raw_path in raws if "_image" in raw_path.name])
labels = sorted([raw_path for raw_path in raws if "_label" in raw_path.name])
# D-R_Z were the first slides to be labelled and their annotations are
# sloppier; uncomment the two lines below to exclude them.
# images = sorted([raw_path for raw_path in raws if "_image" in raw_path.name and "D-R_Z" not in raw_path.name])
# labels = sorted([raw_path for raw_path in raws if "_label" in raw_path.name and "D-R_Z" not in raw_path.name])

processed_dir = Path("processed")
l=224  # side length (px) of the square crops; also the network input size
In [ ]:
# Tile each image/label pair into l x l crops and write them to processed_dir.
random.seed(23)  # deterministic sampling of empty crops
empty = 0   # crops containing no labelled pixels
popu = 0    # crops containing at least one labelled pixel
# Probability of *discarding* an empty crop. With cutoff=1, random.random()
# is (almost surely) < 1, so effectively every empty tile is skipped and only
# populated crops are written out.
cutoff=1

for image_path,label_path in zip(images,labels):
    # BUG FIX: imread's second argument is a read flag (ImreadModes), not a
    # colour-conversion code — cv.COLOR_BGR2GRAY (==6) was being interpreted
    # as IMREAD_ANYDEPTH|IMREAD_ANYCOLOR. Use IMREAD_GRAYSCALE to actually
    # load single-channel images.
    image = cv.imread(image_path.as_posix(), cv.IMREAD_GRAYSCALE)
    label = cv.imread(label_path.as_posix(), cv.IMREAD_GRAYSCALE)

    # imread returns None on failure instead of raising; fail loudly here
    # rather than with a confusing AttributeError below.
    if image is None or label is None:
        raise IOError("could not read " + image_path.as_posix()
                      + " or " + label_path.as_posix())
    if image.shape != label.shape:
        raise ValueError(image_path.as_posix() + label_path.as_posix())
    i_max = image.shape[0]//l   # number of whole tiles vertically
    j_max = image.shape[1]//l   # number of whole tiles horizontally

    # If the cells were labelled as 255, or something else mistakenly,
    # instead of 1, binarise so the class indices are exactly {0, 1}.
    label[label!=0]=1

    for i in range(i_max):
        for j in range(j_max):
            cropped_image = image[l*i:l*(i+1), l*j:l*(j+1)]
            cropped_label = label[l*i:l*(i+1), l*j:l*(j+1)]

            # Tag empty tiles so the training pipeline can filter them out.
            suffix = ""
            if (cropped_label!=0).any():
                popu+=1
            else:
                empty+=1
                if (random.random() < cutoff):
                    continue  # drop this empty tile entirely
                suffix = "_empty"

            tile_tag = "_i" + str(i) + "_j" + str(j) + suffix
            cropped_image_path = processed_dir/(image_path.stem + tile_tag + image_path.suffix)
            cropped_label_path = processed_dir/(label_path.stem + tile_tag + label_path.suffix)
            cv.imwrite(cropped_image_path.as_posix(), cropped_image)
            cv.imwrite(cropped_label_path.as_posix(), cropped_label)
In [ ]:
# Tile counts from the preprocessing pass above.
print(popu)   # crops containing labelled cells
print(empty)  # crops with no labelled pixels (skipped when cutoff=1)

Train the segmentation network (U-Net on a pretrained ResNet encoder).

In [4]:
torch.cuda.set_device(0)
In [5]:
codes = ["NOT-CELL", "CELL"]  # segmentation classes; index 0 is background
bs = 4  # batch size (limited by GPU memory, see notes below)
#bs=16 and l=224 will use ~7300MiB for resnet34  before unfreezing
#bs=4 and l=224 use ~11500MiB for resnet50 before unfreezing
In [6]:
# Training-time augmentations. Flips and rotations are safe for microscopy
# tiles (no canonical orientation); lighting and warp are disabled.
transforms = get_transforms(
    do_flip = True,
    flip_vert = True,
    max_zoom = 1, #consider enabling zoom (>1) later
    max_rotate = 45,
    max_lighting = None,
    max_warp = None,
    p_affine = 0.75,
    p_lighting = 0.75)
In [7]:
get_label_from_image = lambda path: re.sub(r'_image_', '_label_', path.as_posix())

src = (
    SegmentationItemList.from_folder(processed_dir)
    # Train only on populated crops: keep '*image*' files, drop '*empty*'
    # tiles if any were written during preprocessing.
    .filter_by_func(lambda fname:
                    'image' in Path(fname).name and "empty" not in Path(fname).name)
    .split_by_rand_pct(valid_pct=0.20, seed=1)  # fixed seed -> stable split
    .label_from_func(get_label_from_image, classes=codes)
)
data = (
    # tfm_y=True applies the same spatial transforms to the masks.
    src.transform(transforms, tfm_y=True)
    .databunch(bs=bs)
    .normalize(imagenet_stats)  # pretrained encoder expects ImageNet stats
)
In [8]:
data.show_batch(2, figsize=(10,7))
In [8]:
# models.resnet34  (smaller alternative encoder, see memory notes above)
model_path = Path("../../models")
# U-Net with a pretrained ResNet-50 encoder; dice(iou=True) reports IoU.
learn = unet_learner(data, models.resnet50, metrics=partial(dice, iou=True))
# Per-class loss weights for [NOT-CELL, CELL]; currently equal — presumably a
# hook for up-weighting the rare CELL class later. TODO confirm.
learn.loss_func = CrossEntropyFlat(axis=1, weight = torch.Tensor([1,1]).cuda())
In [10]:
# Learning-rate sweep to pick lr for fit_one_cycle below.
lr_find(learn)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [9]:
lr = 2e-3  # chosen from the LR-finder curve above
learn.fit_one_cycle(15, lr)
epoch train_loss valid_loss dice time
0 0.020690 0.017666 0.054046 04:17
1 0.017027 0.014771 0.229396 04:00
2 0.021555 0.016435 0.278650 03:58
3 0.214328 0.017858 0.185597 04:01
4 0.017524 0.016227 0.363482 04:00
5 0.017529 0.015227 0.369190 03:59
6 0.013834 0.017169 0.267608 04:02
7 0.013918 0.013550 0.421578 03:59
8 0.015694 0.013240 0.366751 03:59
9 0.013832 0.012862 0.408747 04:01
10 0.013238 0.012563 0.423296 04:00
11 0.013704 0.012298 0.421512 04:12
12 0.012952 0.012179 0.392708 04:01
13 0.013106 0.012023 0.418428 04:02
14 0.013236 0.012036 0.412774 04:01
In [22]:
learn.save(model_path/"2019-07-02_RESNET50_IOU0.41_1stage")
In [ ]:
!jupyter nbconvert gfp-cnp-train.ipynb --to html --output nbs/2019-06-26_RESNET50_IOU0.41_1stage
In [ ]:
learn.load(model_path/"2019-07-02_RESNET50_IOU0.41_1stage");
In [ ]:
learn.freeze_to(-2)
In [ ]:
# Re-run the LR sweep now that more layers are trainable.
lr_find(learn)
learn.recorder.plot()
In [ ]:
lr=1e-5
# Discriminative learning rates: smallest for the earliest encoder layers.
lrs = slice(lr/1000,lr/10)
learn.fit_one_cycle(20, lrs)
In [ ]:
learn.save(models_path/"2019-06-14_RESNET34_IOU0.25_2stage")
In [ ]:
learn.export(file = models_path/"2019-06-14_RESNET34_IOU0.25_2stage.pkl")

Check: visually inspect validation-set predictions against the ground-truth masks.

In [ ]:
print(learn.data.valid_ds.__len__()) #list of N
print(learn.data.valid_ds[0]) #tuple of input image and segment
print(learn.data.valid_ds[0][1])
# print(learn.data.valid_ds.__len__())
# type(learn.data.valid_ds[0][0])
In [12]:
# preds = learn.get_preds(with_loss=True)
# Per-pixel class probabilities and targets over the whole validation set.
preds = learn.get_preds()
In [ ]:
# Inspect the structure of get_preds() output before slicing it up below.
print(len(preds)) # tuple of list of probs and targets
print(preds[0].shape) #predictions
print(preds[0][0].shape) #probabilities for each label
print(learn.data.classes) #what is each label
print(preds[0][0][0].shape) #probabilities for label 0
# for i in range(0,N):
#     print(torch.max(preds[0][i][1]))

# Image(preds[1][0]).show()
In [13]:
# get_preds() output must line up 1:1 with the validation items.
# (Idiom fixes: len() over __len__(); ValueError now carries a message.)
N = len(learn.data.valid_ds)
if N != preds[1].shape[0]:
    raise ValueError("prediction/target count mismatch: "
                     + str(N) + " items vs " + str(preds[1].shape[0]) + " predictions")

xs = [learn.data.valid_ds[i][0] for i in range(N)]             # input tiles
ys = [learn.data.valid_ds[i][1] for i in range(N)]             # ground-truth masks
p0s = [Image(preds[0][i][0]) for i in range(N)]                # P(NOT-CELL) maps
p1s = [Image(preds[0][i][1]) for i in range(N)]                # P(CELL) maps
argmax = [Image(preds[0][i].argmax(dim=0)) for i in range(N)]  # hard predictions
In [ ]:
# Inspect the pixel-tensor shapes of one example from each list.
print(xs[0].px.shape)
print(ys[0].px.shape)
print(p0s[0].px.shape)
print(p1s[0].px.shape)
In [14]:
# Overlay predicted masks (blue) on ground truth (orange) for every
# validation tile. BUG FIX: the loop used range(1, N), which dropped the last
# tile; nrow also over-allocated a row when N was a multiple of ncol.
ncol = 3
nrow = (N + ncol - 1)//ncol  # ceil(N / ncol): just enough rows for N panels
fig=plt.figure(figsize=(12, nrow*5))

for i in range(N):
    fig.add_subplot(nrow, ncol, i + 1)  # subplot indices are 1-based
#     plt.imshow(xs[i].px.permute(1, 2, 0), cmap = "Oranges", alpha=0.5)
    plt.imshow(argmax[i].px, cmap = "Blues", alpha=0.5)
#     plt.imshow(p1s[i].px, cmap = "Blues", alpha=0.5)
    plt.imshow(ys[i].px[0], cmap = "Oranges", alpha=0.5)
plt.show()
In [19]:
# Overlay predicted masks (blue) on the raw input tiles for visual QA.
# BUG FIX: range(1, N) dropped the last tile; iterate over all N instead.
fig=plt.figure(figsize=(12, nrow*5))

for i in range(N):
    fig.add_subplot(nrow, ncol, i + 1)  # 1-based subplot index
    plt.imshow(xs[i].px.permute(1, 2, 0), cmap = "Greys", alpha=1)
    plt.imshow(argmax[i].px, cmap = "Blues", alpha=0.5)
plt.show()
In [17]:
learn.show_results(rows=16, ds_type=DatasetType.Train)
In [20]:
learn.show_results(rows=16)
In [ ]:
# lrs = slice(lr/400,lr/4)
In [ ]:
# learn.fit_one_cycle(15, lrs, pct_start=0.8)
In [ ]:
# learn.save('stage-2');
In [ ]:
# learn.show_results(rows=6, ds_type=DatasetType.Train)